import torchvision
from sklearn.model_selection import train_test_split
import torch
from sklearn import metrics, preprocessing
from pmlb import fetch_data
from sklearn.utils import shuffle

# Binary classification problems in PMLB that have between 150 and 5000
# data points. Names listed here are the ones get_dataset() loads via PMLB.
pmlb_datasets = [
    'analcatdata_lawsuit',
    'australian',
    'backache',
    'biomed',
    'breast_cancer_wisconsin',
    'breast_cancer',
    'breast_w',
    'breast',
    'buggyCrx',
    'bupa',
    'chess',
    'churn',
    'clean1',
    'cleve',
    'colic',
    'corral',
    'credit_a',
    'credit_g',
    'crx',
    'diabetes',
    'dis',
    'flare',
    'GAMETES_Epistasis_2_Way_1000atts_0.4H_EDM_1_EDM_1_1',
    'GAMETES_Epistasis_2_Way_20atts_0.1H_EDM_1_1',
    'GAMETES_Epistasis_2_Way_20atts_0.4H_EDM_1_1',
    'GAMETES_Epistasis_3_Way_20atts_0.2H_EDM_1_1',
    'GAMETES_Heterogeneity_20atts_1600_Het_0.4_0.2_50_EDM_2_001',
    'GAMETES_Heterogeneity_20atts_1600_Het_0.4_0.2_75_EDM_2_001',
    'german',
    'glass2',
    'haberman',
    'heart_c',
    'heart_h',
    'heart_statlog',
    'hepatitis',
    'Hill_Valley_with_noise',
    'Hill_Valley_without_noise',
    'horse_colic',
    'house_votes_84',
    'hungarian',
    'hypothyroid',
    'ionosphere',
    'irish',
    'kr_vs_kp',
    'mofn_3_7_10',
    'monk1',
    'monk2',
    'monk3',
    'parity5+5',
    'pima',
    'prnn_crabs',
    'prnn_synth',
    'profb',
    'saheart',
    'sonar',
    'spambase',
    'spect',
    'spectf',
    'threeOf9',
    'tic_tac_toe',
    'tokyo1',
    'vote',
    'wdbc',
    'xd6',
]

# Bandwidth (sigma) handed to rbf_kernel() when building kernel features.
SIGMA = 1.0


def get_dataset(dataset_name, train_flag, datadir, exp_dict):
    """Build the train or test dataset for ``dataset_name``.

    Two families of datasets are supported:

    * ``"mnist"`` -- the torchvision MNIST dataset, converted to tensors and
      normalized with mean 0.5 and std 0.5.
    * any name in ``pmlb_datasets`` -- the PMLB table is scaled, split 80/20
      in file order (no shuffling before the split), and both parts are
      expressed as RBF-kernel features against the shuffled training rows.

    Args:
        dataset_name: ``"mnist"`` or an entry of ``pmlb_datasets``.
        train_flag: truthy for the training split, falsy for the test split.
        datadir: directory used for the MNIST download and as the PMLB local
            cache directory.
        exp_dict: experiment configuration; not read by this function.

    Returns:
        The torchvision MNIST dataset, or a ``TensorDataset`` of
        ``(kernel_row, label)`` float tensors for PMLB datasets.

    Raises:
        ValueError: if ``dataset_name`` is not a supported dataset.
    """
    if dataset_name == "mnist":
        mnist_transform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.5,), (0.5,)),
        ])
        return torchvision.datasets.MNIST(datadir, train=train_flag,
                                          download=True,
                                          transform=mnist_transform)

    if dataset_name not in pmlb_datasets:
        # Previously this fell through to ``return dataset`` with the name
        # unbound, which surfaced as an opaque UnboundLocalError.
        raise ValueError(f"Unknown dataset: {dataset_name!r}")

    data = fetch_data(dataset_name, local_cache_dir=datadir)
    y = data['target']
    X = data.drop('target', axis=1)
    # NOTE(review): scaling is applied to the full table before the split, so
    # the rows that end up in the test split contribute to the scaling
    # statistics -- confirm this is intended.
    x = preprocessing.scale(X)

    X_train, X_test, y_train, y_test = train_test_split(
        x, y, test_size=0.2, shuffle=False, random_state=0)

    # Only the training part is shuffled; the split itself is in file order.
    X_train, y_train = shuffle(X_train, y_train, random_state=0)

    # Convert both label containers to plain positional arrays so that
    # tensor construction never depends on the pandas index labels.
    y_train = np.array(y_train)
    y_test = np.array(y_test)

    # Kernel features are always computed against the training rows, so the
    # train matrix is square and the test matrix has one column per train row.
    k_train_X = rbf_kernel(X_train, X_train, SIGMA)
    k_test_X = rbf_kernel(X_test, X_train, SIGMA)

    training_set = torch.utils.data.TensorDataset(
        torch.tensor(k_train_X, dtype=torch.float),
        torch.tensor(y_train, dtype=torch.float))
    test_set = torch.utils.data.TensorDataset(
        torch.tensor(k_test_X, dtype=torch.float),
        torch.tensor(y_test, dtype=torch.float))

    return training_set if train_flag else test_set



# ===========================================================
# Helpers


import numpy as np
from torchvision.datasets import MNIST

def load_mnist(data_dir):
    """Load the MNIST training split as flat NumPy arrays.

    The data is downloaded into ``data_dir`` if it is not already there.

    Args:
        data_dir: root directory passed to the torchvision ``MNIST`` dataset.

    Returns:
        A tuple ``(X, y)``: ``X`` holds the image data divided by 255 and
        reshaped so each sample is a single row; ``y`` holds the targets.
    """
    mnist_train = MNIST(
        data_dir,
        train=True,
        transform=None,
        target_transform=None,
        download=True,
    )

    pixels = mnist_train.data.numpy()
    labels = mnist_train.targets.numpy()

    # Rescale by 255, then collapse every sample into one feature row.
    scaled = pixels / 255.
    flattened = scaled.reshape((scaled.shape[0], -1))
    return flattened, labels



def rbf_kernel(A, B, sigma):
    """Compute the RBF (Gaussian) kernel between the rows of ``A`` and ``B``.

    Each entry is ``exp(-d(a, b)**2 / (2 * sigma**2))`` where ``d`` is the
    Euclidean distance between a row of ``A`` and a row of ``B``.

    Args:
        A: samples for the first axis of the kernel matrix.
        B: samples for the second axis of the kernel matrix.
        sigma: kernel bandwidth.

    Returns:
        The kernel matrix, one entry per (row of ``A``, row of ``B``) pair.
    """
    euclidean = metrics.pairwise.pairwise_distances(A, B, metric="euclidean")
    squared = np.square(euclidean)
    bandwidth_term = 2 * sigma**2
    return np.exp(-1 * squared / bandwidth_term)